Parte 10: Aprendizagem Federada com Agregação de Gradientes Criptografada

Nas últimas seções, temos aprendido sobre computação criptografada através da construção de vários programas simples. Nesta seção, vamos voltar à demonstração de Aprendizagem Federada da Parte 4, onde tivemos um "agregador de confiança" que foi responsável pela média das atualizações do modelo de vários vorkers.

Vamos agora usar nossas novas ferramentas de computação criptografada para remover esse agregador confiável, pois ele é menos do que ideal, pois pressupõe que podemos encontrar alguém confiável o suficiente para ter acesso a essas informações confidenciais. Este nem sempre é o caso.

Assim, nesta parte do tutorial, mostraremos como se pode usar o SMPC para realizar uma agregação segura, de modo a não precisarmos de um "agregador de confiança".

Autores:

Tradução:

Seção 1: Aprendizagem Federada Usual

Primeiro, aqui está um código que realiza o clássico aprendizado federado no Boston Housing Dataset. Esta seção do código é dividida em várias partes.

Configuração


In [ ]:
import pickle

import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torch.utils.data import TensorDataset, DataLoader

class Parser:
    """Parameters for training"""
    def __init__(self):
        self.epochs = 10
        self.lr = 0.001
        self.test_batch_size = 8
        self.batch_size = 8
        self.log_interval = 10
        self.seed = 1
    
args = Parser()

torch.manual_seed(args.seed)
kwargs = {}

Carregando o Conjunto de Dados


In [ ]:
with open('../data/BostonHousing/boston_housing.pickle','rb') as f:
    ((X, y), (X_test, y_test)) = pickle.load(f)

X = torch.from_numpy(X).float()
y = torch.from_numpy(y).float()
X_test = torch.from_numpy(X_test).float()
y_test = torch.from_numpy(y_test).float()
# preprocessing
mean = X.mean(0, keepdim=True)
dev = X.std(0, keepdim=True)
mean[:, 3] = 0. # the feature at column 3 is binary,
dev[:, 3] = 1.  # so we don't standardize it
X = (X - mean) / dev
X_test = (X_test - mean) / dev
train = TensorDataset(X, y)
test = TensorDataset(X_test, y_test)
train_loader = DataLoader(train, batch_size=args.batch_size, shuffle=True, **kwargs)
test_loader = DataLoader(test, batch_size=args.test_batch_size, shuffle=True, **kwargs)

Estrutura da Rede Neural## Neural Network Structure


In [ ]:
class Net(nn.Module):
    def __init__(self):
        super(Net, self).__init__()
        self.fc1 = nn.Linear(13, 32)
        self.fc2 = nn.Linear(32, 24)
        self.fc3 = nn.Linear(24, 1)

    def forward(self, x):
        x = x.view(-1, 13)
        x = F.relu(self.fc1(x))
        x = F.relu(self.fc2(x))
        x = self.fc3(x)
        return x

model = Net()
optimizer = optim.SGD(model.parameters(), lr=args.lr)

Criando Gancho do PyTorch


In [ ]:
import syft as sy

hook = sy.TorchHook(torch)
bob = sy.VirtualWorker(hook, id="bob")
alice = sy.VirtualWorker(hook, id="alice")
james = sy.VirtualWorker(hook, id="james")

compute_nodes = [bob, alice]

Envie dados para os workers
Normalmente eles já o teriam, isto é apenas para fins de demonstração que nós o enviamos manualmente.


In [ ]:
train_distributed_dataset = []

for batch_idx, (data,target) in enumerate(train_loader):
    data = data.send(compute_nodes[batch_idx % len(compute_nodes)])
    target = target.send(compute_nodes[batch_idx % len(compute_nodes)])
    train_distributed_dataset.append((data, target))

Função de Treinamento


In [ ]:
def train(epoch):
    model.train()
    for batch_idx, (data,target) in enumerate(train_distributed_dataset):
        worker = data.location
        model.send(worker)

        optimizer.zero_grad()
        # update the model
        pred = model(data)
        loss = F.mse_loss(pred.view(-1), target)
        loss.backward()
        optimizer.step()
        model.get()
            
        if batch_idx % args.log_interval == 0:
            loss = loss.get()
            print('Train Epoch: {} [{}/{} ({:.0f}%)]\tLoss: {:.6f}'.format(
                epoch, batch_idx * data.shape[0], len(train_loader),
                       100. * batch_idx / len(train_loader), loss.item()))

Função de Teste


In [ ]:
def test():
    model.eval()
    test_loss = 0
    for data, target in test_loader:
        output = model(data)
        test_loss += F.mse_loss(output.view(-1), target, reduction='sum').item() # sum up batch loss
        pred = output.data.max(1, keepdim=True)[1] # get the index of the max log-probability
        
    test_loss /= len(test_loader.dataset)
    print('\nTest set: Average loss: {:.4f}\n'.format(test_loss))

Treinando o Modelo


In [ ]:
import time

In [ ]:
t = time.time()

for epoch in range(1, args.epochs + 1):
    train(epoch)

    
total_time = time.time() - t
print('Total', round(total_time, 2), 's')

Calculando a Performance


In [ ]:
test()

Seção 2: Adicionando Agregação Criptografada

Agora vamos modificar um pouco este exemplo para agregar gradientes usando criptografia. A diferença principal é uma ou duas linhas de código na função train(), que nós vamos destacar. Por enquanto, vamos reprocessar nossos dados e inicializar um modelo para bob e alice.


In [ ]:
remote_dataset = (list(),list())

train_distributed_dataset = []

for batch_idx, (data,target) in enumerate(train_loader):
    data = data.send(compute_nodes[batch_idx % len(compute_nodes)])
    target = target.send(compute_nodes[batch_idx % len(compute_nodes)])
    remote_dataset[batch_idx % len(compute_nodes)].append((data, target))

def update(data, target, model, optimizer):
    model.send(data.location)
    optimizer.zero_grad()
    pred = model(data)
    loss = F.mse_loss(pred.view(-1), target)
    loss.backward()
    optimizer.step()
    return model

bobs_model = Net()
alices_model = Net()

bobs_optimizer = optim.SGD(bobs_model.parameters(), lr=args.lr)
alices_optimizer = optim.SGD(alices_model.parameters(), lr=args.lr)

models = [bobs_model, alices_model]
params = [list(bobs_model.parameters()), list(alices_model.parameters())]
optimizers = [bobs_optimizer, alices_optimizer]

Construindo nossa Lógica de Treinamento

A única diferença real está dentro deste método de treino. Vamos vê-lo passo-a-passo.

Parte A: Treino:


In [ ]:
# this is selecting which batch to train on
data_index = 0
# update remote models
# we could iterate this multiple times before proceeding, but we're only iterating once per worker here
for remote_index in range(len(compute_nodes)):
    data, target = remote_dataset[remote_index][data_index]
    models[remote_index] = update(data, target, models[remote_index], optimizers[remote_index])

Parte B: Agregação Criptografada


In [ ]:
# create a list where we'll deposit our encrypted model average
new_params = list()

In [ ]:
# iterate through each parameter
for param_i in range(len(params[0])):

    # for each worker
    spdz_params = list()
    for remote_index in range(len(compute_nodes)):
        
        # select the identical parameter from each worker and copy it
        copy_of_parameter = params[remote_index][param_i].copy()
        
        # since SMPC can only work with integers (not floats), we need
        # to use Integers to store decimal information. In other words,
        # we need to use "Fixed Precision" encoding.
        fixed_precision_param = copy_of_parameter.fix_precision()
        
        # now we encrypt it on the remote machine. Note that 
        # fixed_precision_param is ALREADY a pointer. Thus, when
        # we call share, it actually encrypts the data that the
        # data is pointing TO. This returns a POINTER to the 
        # MPC secret shared object, which we need to fetch.
        encrypted_param = fixed_precision_param.share(bob, alice, crypto_provider=james)
        
        # now we fetch the pointer to the MPC shared value
        param = encrypted_param.get()
        
        # save the parameter so we can average it with the same parameter
        # from the other workers
        spdz_params.append(param)

    # average params from multiple workers, fetch them to the local machine
    # decrypt and decode (from fixed precision) back into a floating point number
    new_param = (spdz_params[0] + spdz_params[1]).get().float_precision()/2
    
    # save the new averaged parameter
    new_params.append(new_param)

Parte C: Limpeza


In [ ]:
with torch.no_grad():
    for model in params:
        for param in model:
            param *= 0

    for model in models:
        model.get()

    for remote_index in range(len(compute_nodes)):
        for param_index in range(len(params[remote_index])):
            params[remote_index][param_index].set_(new_params[param_index])

Vamos Juntar Tudo!!

E agora que conhecemos cada etapa, nós podemos colocar todas juntas em mesmo ciclo de treinamento.


In [ ]:
def train(epoch):
    for data_index in range(len(remote_dataset[0])-1):
        # update remote models
        for remote_index in range(len(compute_nodes)):
            data, target = remote_dataset[remote_index][data_index]
            models[remote_index] = update(data, target, models[remote_index], optimizers[remote_index])

        # encrypted aggregation
        new_params = list()
        for param_i in range(len(params[0])):
            spdz_params = list()
            for remote_index in range(len(compute_nodes)):
                spdz_params.append(params[remote_index][param_i].copy().fix_precision().share(bob, alice, crypto_provider=james).get())

            new_param = (spdz_params[0] + spdz_params[1]).get().float_precision()/2
            new_params.append(new_param)

        # cleanup
        with torch.no_grad():
            for model in params:
                for param in model:
                    param *= 0

            for model in models:
                model.get()

            for remote_index in range(len(compute_nodes)):
                for param_index in range(len(params[remote_index])):
                    params[remote_index][param_index].set_(new_params[param_index])

In [ ]:
def test():
    models[0].eval()
    test_loss = 0
    for data, target in test_loader:
        output = models[0](data)
        test_loss += F.mse_loss(output.view(-1), target, reduction='sum').item() # sum up batch loss
        pred = output.data.max(1, keepdim=True)[1] # get the index of the max log-probability
        
    test_loss /= len(test_loader.dataset)
    print('Test set: Average loss: {:.4f}\n'.format(test_loss))

In [ ]:
t = time.time()

for epoch in range(args.epochs):
    print(f"Epoch {epoch + 1}")
    train(epoch)
    test()

    
total_time = time.time() - t
print('Total', round(total_time, 2), 's')

Parabéns!!! - Hora de se juntar a comunidade!

Parabéns por concluir esta etapa do tutorial! Se você gostou e gostaria de se juntar ao movimento em direção à proteção de privacidade, propriedade descentralizada e geração, demanda em cadeia, de dados em IA, você pode fazê-lo das seguintes maneiras!

Dê-nos uma estrela em nosso repo do PySyft no GitHub

A maneira mais fácil de ajudar nossa comunidade é adicionando uma estrela nos nossos repositórios! Isso ajuda a aumentar a conscientização sobre essas ferramentas legais que estamos construindo.

Junte-se ao Slack!

A melhor maneira de manter-se atualizado sobre os últimos avanços é se juntar à nossa comunidade! Você pode fazer isso preenchendo o formulário em http://slack.openmined.org

Contribua com o projeto!

A melhor maneira de contribuir para a nossa comunidade é se tornando um contribuidor do código! A qualquer momento, você pode acessar a página de Issues (problemas) do PySyft no GitHub e filtrar por "Projetos". Isso mostrará todas as etiquetas (tags) na parte superior, com uma visão geral de quais projetos você pode participar! Se você não deseja ingressar em um projeto, mas gostaria de codificar um pouco, também pode procurar mais mini-projetos "independentes" pesquisando problemas no GitHub marcados como "good first issue".

Doar

Se você não tem tempo para contribuir com nossa base de códigos, mas ainda deseja nos apoiar, também pode se tornar um Apoiador em nosso Open Collective. Todas as doações vão para hospedagem na web e outras despesas da comunidade, como hackathons e meetups!

Página do Open Collective do OpenMined


In [ ]: